import torch.nn as nn
import torch.nn.init as init
import numpy as np
import torch

def mixup_data(x,
               y,
               alpha=1.0,
               device='cpu',
               append_noise=True,
               noise_type='gaussian',
               class_conditional=True,
               num_classes=10):
    """Compute the mixup data. Return mixed inputs, pairs of targets, and lambda.

    Optionally appends ``x.shape[0] // num_classes`` synthetic noise samples,
    labelled ``num_classes`` (one index past the real classes), before mixing.
    The caller's ``x`` is never modified in place.

    Args:
        x: Batch tensor with samples on dim 0 and channels on dim 1 (the noise
            code needs at least 2 dims, e.g. ``(B, C, H, W)``).
        y: Integer label tensor aligned with dim 0 of ``x``.
        alpha: Concentration of the Beta(alpha, alpha) draw for ``lam``;
            ``alpha <= 0`` gives ``lam = 1.0`` (no mixing).
        device: Device that ``x`` and ``y`` are moved to once noise has been
            appended.
        append_noise: Whether to append the noise samples.
        noise_type: ``'uni'``/``'uniform'`` -> ``torch.rand`` noise in [0, 1);
            ``'gauss'``/``'gaussian'`` -> Gaussian noise with each channel's
            batch mean and (unbiased) std. Case-insensitive.
        class_conditional: If True, each sample of class ``i``
            (``0 <= i < num_classes``) is mixed with a shuffled sample of the
            same class; otherwise with a shuffled sample from the whole batch.
        num_classes: Number of real classes.

    Returns:
        ``(mixed_x, y_a, y_b, lam)``. In class-conditional mode ``y_a`` and
        ``y_b`` are both ``y``; otherwise ``y_b`` is ``y`` permuted the same
        way as the mixing partners.

    Raises:
        ValueError: If ``append_noise`` is True and ``noise_type`` is unknown.
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0. else 1.0

    if append_noise:
        num_noise = x.shape[0] // num_classes
        noise_shape = (num_noise, *x.shape[1:])
        # Integer inputs get floating-point noise; torch.cat below promotes.
        noise_dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
        kind = noise_type.lower()
        if kind in ('uni', 'uniform'):
            noise = torch.rand(noise_shape, dtype=noise_dtype, device=x.device)
        elif kind in ('gauss', 'gaussian'):
            # Per-channel statistics: reduce over every dim except dim 1.
            reduce_dims = (0, *range(2, x.dim()))
            stats_src = x.detach().to(noise_dtype)
            mean = stats_src.mean(dim=reduce_dims, keepdim=True)
            std = stats_src.std(dim=reduce_dims, keepdim=True)
            noise = torch.randn(noise_shape, dtype=noise_dtype,
                                device=x.device) * std + mean
        else:
            raise ValueError(
                f"unknown noise_type {noise_type!r}; "
                "expected 'uni'/'uniform' or 'gauss'/'gaussian'")

        noise_labels = torch.full((num_noise,), num_classes,
                                  dtype=torch.long, device=y.device)
        y = torch.cat([y, noise_labels], dim=0).to(device)
        x = torch.cat([x, noise], dim=0).to(device)

    # Counted *after* the noise rows are appended so the permutation covers
    # every row of x and y (otherwise the shapes below do not line up).
    batch_size = x.shape[0]

    if not class_conditional:
        index = torch.randperm(batch_size, device=x.device)
        mixed_x = lam * x + (1 - lam) * x[index]
        return mixed_x, y, y[index], lam

    # Same-class mixup. Only labels 0..num_classes-1 are visited, so appended
    # noise rows (label num_classes) pass through unmixed.
    n = 2  # Number of images within the same class to mix
    source = x.detach()   # mixing partners; no gradient flows through them
    mixed_x = x.clone()   # write into a copy, never into the caller's tensor
    for i in range(num_classes):
        index = torch.nonzero(y == i).squeeze(-1).to(x.device)
        for _ in range(n - 1):
            partner = index[torch.randperm(index.numel(), device=index.device)]
            mixed_x[index] = lam * mixed_x[index] + (1 - lam) * source[partner]

    return mixed_x, y, y, lam

def soft_cross_entropy_loss(logits, y_a, y_b, lam, num_classes):
    """Cross-entropy of ``logits`` against the soft target ``lam*y_a + (1-lam)*y_b``.

    Args:
        logits: Unnormalised scores; the softmax is taken over dim 1 and the
            result is multiplied elementwise with ``(batch, num_classes)``
            one-hot targets.
        y_a: Integer class indices of the first mixed component.
        y_b: Integer class indices of the second mixed component.
        lam: Weight given to ``y_a``; ``y_b`` gets ``1 - lam``.
        num_classes: Width of the one-hot encoding; must be larger than every
            label. NOTE(review): labels of noise rows appended by
            ``mixup_data`` equal its ``num_classes``, so this presumably needs
            to be one larger in that case — confirm at the call site.

    Returns:
        Scalar tensor: the loss summed over all entries divided by
        ``logits.shape[0]`` (i.e. the per-sample mean).
    """
    soft_labels = (torch.nn.functional.one_hot(y_a, num_classes) * lam
                   + torch.nn.functional.one_hot(y_b, num_classes) * (1 - lam))
    # log_softmax is numerically stable. log(softmax(...)) underflows to
    # log(0) = -inf for confident logits, and 0 * -inf = NaN poisons the sum.
    log_prob = torch.nn.functional.log_softmax(logits, dim=1)
    loss = -(soft_labels * log_prob).sum()
    return loss / logits.shape[0]